
Add optional Gated Attention#22372

Open
pass-lin wants to merge 11 commits into keras-team:master from pass-lin:master

Conversation

@pass-lin
Contributor

@pass-lin pass-lin commented Mar 6, 2026

from #22337

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an optional gated attention mechanism to the GroupedQueryAttention and MultiHeadAttention layers. This new feature aims to enhance the stability of attention models during training and mitigate issues like "attention sinks" by dynamically controlling the flow of information through a sigmoid-activated gate, as suggested by recent research.

Highlights

  • New Feature: Gated Attention: Introduced an optional use_gate parameter to GroupedQueryAttention and MultiHeadAttention layers, enabling a gated attention mechanism.
  • Implementation of Gated Attention: Implemented the gated attention logic, which applies a sigmoid-activated linear projection to the query to gate the attention output, aiming to improve training stability and eliminate "attention sinks" as described in the docstring.
  • Comprehensive Testing: Added comprehensive unit tests for both GroupedQueryAttention and MultiHeadAttention to cover the new use_gate functionality, including basic functionality, shape computation, initializer behavior, masking, constraints, LoRA, and quantization.
  • Code Formatting: Performed minor code formatting adjustments in backend trainer.py, core.py files, and layer.py for improved readability.
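The mechanism the highlights describe (a sigmoid-activated linear projection of the query, multiplied elementwise into the per-head attention output) can be sketched in plain NumPy. This is a toy illustration with assumed shapes and names, not the actual Keras implementation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Assumed toy dimensions: batch, sequence length, heads, per-head dim, features.
batch, seq, num_heads, value_dim, features = 2, 5, 3, 4, 8
rng = np.random.default_rng(0)

query = rng.standard_normal((batch, seq, features))
# Per-head attention output, e.g. softmax(QK^T)V (precomputed here).
attention_output = rng.standard_normal((batch, seq, num_heads, value_dim))

# Gate: sigmoid of a linear projection of the query, shaped like the output.
gate_kernel = 0.1 * rng.standard_normal((features, num_heads, value_dim))
gate = sigmoid(np.einsum("bqf,fnv->bqnv", query, gate_kernel))

# Elementwise gating of the attention output.
gated_output = attention_output * gate

assert gated_output.shape == (batch, seq, num_heads, value_dim)
```

Because the gate values lie in (0, 1), each head's output can only be attenuated, never amplified, which is what allows the layer to suppress "attention sink" positions.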


Changelog
  • keras/src/backend/jax/trainer.py
    • Reformatted a lambda expression for improved readability.
  • keras/src/backend/numpy/core.py
    • Reformatted a lambda expression for improved readability.
  • keras/src/backend/torch/core.py
    • Reformatted a lambda expression for improved readability.
  • keras/src/layers/attention/grouped_query_attention.py
    • Added use_gate parameter to the layer's constructor and configuration.
    • Implemented the conditional creation of a _gate_dense layer.
    • Modified the call method to apply the gating mechanism to the attention output.
  • keras/src/layers/attention/grouped_query_attention_test.py
    • Added new test cases to verify the functionality of GroupedQueryAttention when use_gate is enabled.
    • Extended existing tests for compute_output_shape, initializer, query_mask_propagation, and masking to include use_gate=True scenarios.
  • keras/src/layers/attention/multi_head_attention.py
    • Added _use_gate attribute to the layer's constructor and configuration.
    • Implemented the conditional creation of a _gate_dense layer in the build method.
    • Modified the call method to apply the gating mechanism to the attention output.
  • keras/src/layers/attention/multi_head_attention_test.py
    • Added new test cases to verify the functionality of MultiHeadAttention when use_gate is enabled.
    • Extended existing tests for basics, high_dim_attention, attention_axes_negative_indexing, compute_output_shape, initializer, mha_constraints, lora, and quantize_int8 to include use_gate=True scenarios.
  • keras/src/layers/layer.py
    • Reformatted lambda expressions within the _initialize_tracker method for improved readability.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optional gated attention mechanism to GroupedQueryAttention and MultiHeadAttention. While the feature is a good addition, the implementation has some critical issues. The docstrings reference a non-existent paper, which is misleading and violates the repository's contribution guidelines. Additionally, there is a bug in the MultiHeadAttention gating logic that will cause a runtime error when key_dim and value_dim differ. I've provided detailed comments and suggestions to address these issues.

Note: Security Review did not run due to the size of the PR.

bias_constraint: Constraint for dense layer kernels.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on the
(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708].
Contributor


critical

The reference to "(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]" appears to be a placeholder. The year is in the future and the arXiv link is invalid. This violates the Keras API design guidelines, which require new features to be based on widely recognized best practices. Please replace this with a valid reference to the paper that introduced this gated attention mechanism, or provide a more general explanation of the technique if a specific paper isn't the source.

References
  1. New features should be widely recognized as a machine learning best practice and not based on very recent or non-existent papers. (link)

bias_constraint: Constraint for dense layer kernels.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on the
(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708].
Contributor


critical

The reference to "(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]" appears to be a placeholder. The year is in the future and the arXiv link is invalid. This violates the Keras API design guidelines, which require new features to be based on widely recognized best practices. Please replace this with a valid reference to the paper that introduced this gated attention mechanism, or provide a more general explanation of the technique if a specific paper isn't the source.

References
  1. New features should be widely recognized as a machine learning best practice and not based on very recent or non-existent papers. (link)

Collaborator


The link is fine; I'm not sure what Gemini is looking at. However, can you change the title from NeurIPS 2025 Best Paper to Gated Attention for Large Language Models?

Comment on lines +283 to +294
if self._use_gate:
    self._gate_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(
            output_rank - 1, [self._num_heads, self._key_dim]
        ),
        bias_axes=bias_axes if self._use_bias else None,
        activation="sigmoid",
        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(key_shape)
Contributor


critical

The implementation of the gated attention has a few issues that will lead to errors or incorrect behavior:

  1. Shape Mismatch Bug: The _gate_dense layer's output dimension is set to self._key_dim, but it's multiplied with attention_output, which has a dimension of self._value_dim. This will cause a runtime error if key_dim != value_dim.
  2. Incorrect Equation: The einsum_equation used is from the _key_dense layer, but the gate is a projection of the query. It should use an equation based on the query's rank.
  3. Incorrect Build Shape: The _gate_dense layer is built with key_shape, but it should be built with query_shape since it processes the query.

To fix these issues, the _gate_dense layer should be defined using the query's properties and its output dimension should be self._value_dim. This may require refactoring the build method slightly to define _gate_dense after _query_dense to reuse its equation variables.

Here is a suggestion to fix the most critical parts (the shape mismatch and incorrect build shape):

Suggested change

Remove:

if self._use_gate:
    self._gate_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(
            output_rank - 1, [self._num_heads, self._key_dim]
        ),
        bias_axes=bias_axes if self._use_bias else None,
        activation="sigmoid",
        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(key_shape)

Replace with:

if self._use_gate:
    query_einsum_equation, query_bias_axes, query_output_rank = (
        _build_proj_equation(query_rank - 1, bound_dims=1, output_dims=2)
    )
    self._gate_dense = EinsumDense(
        query_einsum_equation,
        output_shape=_get_output_shape(
            query_output_rank - 1, [self._num_heads, self._value_dim]
        ),
        bias_axes=query_bias_axes if self._use_bias else None,
        activation="sigmoid",
        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(query_shape)
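The shape-mismatch bug can be reproduced in a standalone NumPy sketch (toy shapes and names are assumptions, not the actual Keras internals): a gate projected to key_dim cannot be broadcast against an attention output whose last axis is value_dim.

```python
import numpy as np

# Assumed toy dimensions, deliberately with key_dim != value_dim.
batch, seq, num_heads, key_dim, value_dim, features = 2, 5, 3, 4, 6, 8
rng = np.random.default_rng(0)

query = rng.standard_normal((batch, seq, features))
attention_output = rng.standard_normal((batch, seq, num_heads, value_dim))

# Buggy variant: gate projected to key_dim (4) instead of value_dim (6).
bad_kernel = rng.standard_normal((features, num_heads, key_dim))
bad_gate = np.einsum("bqf,fnk->bqnk", query, bad_kernel)
mismatch = False
try:
    attention_output * bad_gate  # (..., 6) vs (..., 4): cannot broadcast
except ValueError:
    mismatch = True

# Corrected variant: gate projected to value_dim, matching the output.
good_kernel = rng.standard_normal((features, num_heads, value_dim))
good_gate = np.einsum("bqf,fnv->bqnv", query, good_kernel)
gated = attention_output * good_gate
```

With key_dim == value_dim the bug goes unnoticed, which is presumably why the existing tests pass; it only surfaces when the two dimensions differ.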

        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(key_shape)
Contributor


medium

The _gate_dense layer is a projection of the query, but it's being built using key_shape. While this may not cause an error if EinsumDense only relies on the feature dimension (which is the same for query and key), it's semantically incorrect and confusing for future maintenance. For clarity and correctness, please build this layer using query_shape.

Suggested change

Remove:

    self._gate_dense.build(key_shape)

Replace with:

    self._gate_dense.build(query_shape)

@pass-lin
Contributor Author

pass-lin commented Mar 6, 2026

/gemini review

@codecov-commenter

codecov-commenter commented Mar 6, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 83.04%. Comparing base (340b110) to head (834d7cb).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #22372   +/-   ##
=======================================
  Coverage   83.04%   83.04%           
=======================================
  Files         596      596           
  Lines       66708    66725   +17     
  Branches    10384    10390    +6     
=======================================
+ Hits        55395    55412   +17     
  Misses       8676     8676           
  Partials     2637     2637           
Flag Coverage Δ
keras 82.87% <100.00%> (+<0.01%) ⬆️
keras-jax 60.54% <100.00%> (+0.01%) ⬆️
keras-numpy 54.78% <100.00%> (+0.01%) ⬆️
keras-openvino 49.96% <100.00%> (+0.01%) ⬆️
keras-tensorflow 61.79% <100.00%> (+<0.01%) ⬆️
keras-torch 60.61% <100.00%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown.


@pass-lin
Contributor Author

pass-lin commented Mar 6, 2026

I need some help running this test.
@fchollet

@parameterized.named_parameters(
    ("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
    ("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)),
    ("4d_inputs_1freebatch_mask4", (3, 4), (3, 2), (3, 2, 4, 2), (2,)),
    ("4d_inputs_2d_attention", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)),
    ("5d_inputs_2d_attention", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)),
    (
        "5d_inputs_2d_attention_fullmask",
        (5, 3, 4),
        (5, 3, 2),
        (5, 3, 4, 3, 2),
        (2, 3),
    ),
)
def test_high_dim_attention(
    self, q_dims, v_dims, mask_dims, attention_axes
):
    batch_size, hidden_size = 3, 8
    query_shape = (batch_size,) + q_dims + (hidden_size,)
    value_shape = (batch_size,) + v_dims + (hidden_size,)
    self.run_layer_test(
        layers.MultiHeadAttention,
        init_kwargs={
            "num_heads": 2,
            "key_dim": 2,
            "attention_axes": attention_axes,
        },
        input_shape={
            "query_shape": query_shape,
            "value_shape": value_shape,
        },
        expected_output_shape=query_shape,
        expected_num_trainable_weights=8,
        expected_num_non_trainable_weights=0,
        expected_num_seed_generators=0,
        expected_num_losses=0,
        supports_masking=True,
        run_training_check=False,
    )

    self.run_layer_test(
        layers.MultiHeadAttention,
        init_kwargs={
            "num_heads": 2,
            "key_dim": 2,
            "use_gate": True,
            "attention_axes": attention_axes,
        },
        input_shape={
            "query_shape": query_shape,
            "value_shape": value_shape,
        },
        expected_output_shape=query_shape,
        expected_num_trainable_weights=10,
        expected_num_non_trainable_weights=0,
        expected_num_seed_generators=0,
        expected_num_losses=0,
        supports_masking=True,
        run_training_check=False,
    )

The error occurs only with the OpenVINO backend:

FAILED keras/src/layers/attention/multi_head_attention_test.py::MultiHeadAttentionTest::test_high_dim_attention_4d_inputs_1freebatch_mask3 - ValueError: expected non-negative integer

I discovered that running this part works fine:

batch_size, hidden_size = 3, 8
query_shape = (batch_size,) + q_dims + (hidden_size,)
value_shape = (batch_size,) + v_dims + (hidden_size,)
self.run_layer_test(
    layers.MultiHeadAttention,
    init_kwargs={
        "num_heads": 2,
        "key_dim": 2,
        "attention_axes": attention_axes,
    },
    input_shape={
        "query_shape": query_shape,
        "value_shape": value_shape,
    },
    expected_output_shape=query_shape,
    expected_num_trainable_weights=8,
    expected_num_non_trainable_weights=0,
    expected_num_seed_generators=0,
    expected_num_losses=0,
    supports_masking=True,
    run_training_check=False,
)

But when I run the following, it fails:

self.run_layer_test(
    layers.MultiHeadAttention,
    init_kwargs={
        "num_heads": 2,
        "key_dim": 2,
        "attention_axes": attention_axes,
    },
    input_shape={
        "query_shape": query_shape,
        "value_shape": value_shape,
    },
    expected_output_shape=query_shape,
    expected_num_trainable_weights=10,
    expected_num_non_trainable_weights=0,
    expected_num_seed_generators=0,
    expected_num_losses=0,
    supports_masking=True,
    run_training_check=False,
)

The error is always:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[9], line 4
      2 query_shape = (batch_size,) + q_dims + (hidden_size,)
      3 value_shape = (batch_size,) + v_dims + (hidden_size,)
----> 4 self.run_layer_test(
      5     layers.MultiHeadAttention,
      6     init_kwargs={
      7         "num_heads": 2,
      8         "key_dim": 2,
      9         "attention_axes": attention_axes,
     10     },
     11     input_shape={
     12         "query_shape": query_shape,
     13         "value_shape": value_shape,
     14     },
     15     expected_output_shape=query_shape,
     16     expected_num_trainable_weights=8,
     17     expected_num_non_trainable_weights=0,
     18     expected_num_seed_generators=0,
     19     expected_num_losses=0,
     20     supports_masking=True,
     21     run_training_check=False,
     22 )

File /mnt/d/keras/keras/src/testing/test_case.py:320, in TestCase.run_layer_test(self, layer_cls, init_kwargs, input_shape, input_dtype, input_sparse, input_ragged, input_data, call_kwargs, expected_output_shape, expected_output_dtype, expected_output_sparse, expected_output_ragged, expected_output, expected_num_trainable_weights, expected_num_non_trainable_weights, expected_num_non_trainable_variables, expected_num_seed_generators, expected_num_losses, supports_masking, expected_mask_shape, custom_objects, run_training_check, run_mixed_precision_check, assert_built_after_instantiation, tpu_atol, tpu_rtol)
    318 if input_data is not None or input_shape is not None:
    319     if input_data is None:
--> 320         input_data = create_eager_tensors(
    321             input_shape, input_dtype, input_sparse, input_ragged
    322         )
    323     layer = layer_cls(**init_kwargs)
    324     if isinstance(input_data, dict):

File /mnt/d/keras/keras/src/testing/test_case.py:778, in create_eager_tensors(input_shape, dtype, sparse, ragged)
    773         return ops.cast(
    774             random.uniform(shape, dtype="float32") * 3, dtype=dt
    775         )
    777 if isinstance(input_shape, dict):
--> 778     return {
    779         utils.removesuffix(k, "_shape"): create_fn(v, dtype[k])
    780         for k, v in input_shape.items()
    781     }
    782 return map_shape_dtype_structure(create_fn, input_shape, dtype)

File /mnt/d/keras/keras/src/testing/test_case.py:779, in <dictcomp>(.0)
    773         return ops.cast(
    774             random.uniform(shape, dtype="float32") * 3, dtype=dt
    775         )
    777 if isinstance(input_shape, dict):
    778     return {
--> 779         utils.removesuffix(k, "_shape"): create_fn(v, dtype[k])
    780         for k, v in input_shape.items()
    781     }
    782 return map_shape_dtype_structure(create_fn, input_shape, dtype)

File /mnt/d/keras/keras/src/testing/test_case.py:774, in create_eager_tensors.<locals>.create_fn(shape, dt)
    772 def create_fn(shape, dt):
    773     return ops.cast(
--> 774         random.uniform(shape, dtype="float32") * 3, dtype=dt
    775     )

File /mnt/d/keras/keras/src/backend/openvino/random.py:31, in uniform(shape, minval, maxval, dtype, seed)
     29 else:
     30     seed_data = seed_val.data
---> 31 rng = np.random.default_rng(seed_data)
     32 random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)
     33 return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))

File numpy/random/_generator.pyx:4957, in numpy.random._generator.default_rng()

File _pcg64.pyx:123, in numpy.random._pcg64.PCG64.__init__()

File bit_generator.pyx:535, in numpy.random.bit_generator.BitGenerator.__init__()

File bit_generator.pyx:315, in numpy.random.bit_generator.SeedSequence.__init__()

File bit_generator.pyx:389, in numpy.random.bit_generator.SeedSequence.get_assembled_entropy()

File bit_generator.pyx:148, in numpy.random.bit_generator._coerce_to_uint32_array()

File bit_generator.pyx:140, in numpy.random.bit_generator._coerce_to_uint32_array()

File bit_generator.pyx:70, in numpy.random.bit_generator._int_to_uint32_array()

ValueError: expected non-negative integer

I tried modifying the code:

def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    seed_val = draw_seed(seed)
    if isinstance(seed_val, OpenVINOKerasTensor):
        seed_data = convert_to_numpy(seed_val)
    else:
        seed_data = seed_val.data
    print(seed_data)
    rng = np.random.default_rng(seed_data)
    random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)
    return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))

I printed seed_data and got [740428769 0] for the working case, and [-359424 5387] for the failing case.

@divyashreepathihalli
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optional gated attention mechanism to the GroupedQueryAttention and MultiHeadAttention layers, a useful feature for improving training stability. The implementation is sound and accompanied by thorough tests. My review focuses on correcting the academic citations in the docstrings, which currently contain inaccuracies and appear to be hallucinations. The remaining changes are minor stylistic adjustments that improve code formatting.

@keras-team keras-team deleted a comment from gemini-code-assist bot Mar 6, 2026
@keras-team keras-team deleted a comment from gemini-code-assist bot Mar 6, 2026
@MarcosAsh
Contributor

Hey, the OpenVINO failure isn't related to your gated attention changes; it's a pre-existing issue with the OpenVINO random backend. In keras/src/backend/openvino/random.py, the uniform function passes seed values directly to np.random.default_rng(), which requires non-negative integers. When two run_layer_test calls run in the same test, the seed state can wrap to negative values (like the [-359424, 5387] you printed), causing the ValueError.

You should be able to fix it by wrapping the seed data with np.abs():

rng = np.random.default_rng(np.abs(seed_data))

The normal and truncated_normal functions in the same file (lines 19 and 133) have the same issue with seed.data, so those should probably be fixed too.
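The failure mode and the proposed np.abs() workaround can be demonstrated directly; this is a minimal sketch, with the array value taken from the seed state printed earlier in the thread:

```python
import numpy as np

# np.random.default_rng() rejects negative integers anywhere in the seed.
bad_seed = np.array([-359424, 5387])  # failing seed state from the report
raised = False
try:
    np.random.default_rng(bad_seed)
except ValueError:
    raised = True  # "expected non-negative integer"

# Workaround: take the absolute value before seeding. This produces a
# valid (if different) seed state and restores deterministic sampling.
rng = np.random.default_rng(np.abs(bad_seed))
values = rng.uniform(0.0, 1.0, size=(3,))
```

Note that np.abs() maps two distinct seed states onto one, so seeds differing only in sign would collide; for a test-time fix that is likely acceptable, but a masking scheme (e.g. casting to an unsigned dtype) would avoid the collision.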

Collaborator

@hertschuh hertschuh left a comment


Thanks for adding this!

Comment on lines +293 to +319
# Create two layers with equivalent positive and negative indices
mha_pos = layers.MultiHeadAttention(
    num_heads=2, key_dim=4, attention_axes=2, use_gate=True
)
mha_neg = layers.MultiHeadAttention(
    num_heads=2, key_dim=4, attention_axes=-2, use_gate=True
)

# Initialize both layers
_ = mha_pos(x, x)
_ = mha_neg(x, x)

# Set same weights for fair comparison
mha_neg.set_weights(mha_pos.get_weights())

# Get outputs and attention scores
z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)

# Verify shapes match
self.assertEqual(z_pos.shape, z_neg.shape)
self.assertEqual(a_pos.shape, a_neg.shape)

# Verify outputs are identical
self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)

Collaborator


I don't think this test is particularly useful. I believe if it works with use_gate=False, it will work with use_gate=True.

However, I'd like to see a test that validates that use_gate=True actually does something different. Maybe by creating 2 layers, one with use_gate=False and one with use_gate=True and comparing their outputs. Although that is a weak verification, maybe you can think of something better.

Contributor Author


I don't think this test is particularly useful. I believe if it works with use_gate=False, it will work with use_gate=True.

However, I'd like to see a test that validates that use_gate=True actually does something different. Maybe by creating 2 layers, one with use_gate=False and one with use_gate=True and comparing their outputs. Although that is a weak verification, maybe you can think of something better.

I think this test can be kept to verify that the use_gate=True workflow works properly.

Collaborator

@hertschuh hertschuh left a comment


Please rebase, this will remove the formatting changes.
